import os
import torch
import random

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import pandas as pd
import torch.nn.functional as F
import sys
from datasets import Dataset
from transformers.trainer_callback import EarlyStoppingCallback
import pickle
from transformers import TrainerCallback



# Seed Python's `random` module and PyTorch (CPU generator and every CUDA
# device) with the same value so runs start from the same RNG state.
seed_val = 42
random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)


# Identifier handed to `from_pretrained` for both the model and the tokenizer.
# NOTE(review): there is no `org/` prefix, so this presumably resolves to a
# local directory rather than a Hugging Face hub repository — confirm.
model_name = "Llama-2-7b-hf"

################################################################################
# QLoRA parameters
################################################################################

# LoRA attention dimension (rank of the low-rank update matrices)
lora_r = 2

# Alpha parameter for LoRA scaling
lora_alpha = 8

# Dropout probability for LoRA layers
lora_dropout = 0.1

################################################################################
# bitsandbytes parameters
################################################################################

# Activate 4-bit precision base model loading
use_4bit = True

# Compute dtype for 4-bit base models, given as the name of a `torch` dtype
# attribute (it is resolved with `getattr(torch, ...)` when the model is loaded)
bnb_4bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False

################################################################################
# TrainingArguments parameters
################################################################################

# Output directory where the model predictions and checkpoints will be stored
output_dir = "./results"

# Number of training epochs
num_train_epochs = 1

# Enable fp16/bf16 training (set bf16 to True with an A100)
fp16 = False
bf16 = False

# Batch size per GPU for training
per_device_train_batch_size = 1

# Batch size per GPU for evaluation
per_device_eval_batch_size = 1

# Number of update steps to accumulate the gradients for
gradient_accumulation_steps = 1

# Enable gradient checkpointing
# NOTE(review): this flag is not referenced anywhere else in the visible
# file — confirm whether it should be forwarded to TrainingArguments.
gradient_checkpointing = True

# Maximum gradient norm (gradient clipping)
max_grad_norm = 0.3

# Initial learning rate (AdamW optimizer)
learning_rate = 2e-4

# Weight decay to apply to all layers except bias/LayerNorm weights
weight_decay = 0.001

# Optimizer to use
optim = "paged_adamw_32bit"

# Learning rate schedule
lr_scheduler_type = "cosine"

# Number of training steps (overrides num_train_epochs); -1 leaves the
# epoch-based limit in charge
max_steps = -1

# Ratio of steps for a linear warmup (from 0 to learning rate)
warmup_ratio = 0.03

# Group sequences into batches with same length
# Saves memory and speeds up training considerably
group_by_length = True

# Save checkpoint every X update steps
save_steps = 0

# Log every X update steps
logging_steps = 25

################################################################################
# SFT parameters
################################################################################

# Maximum sequence length to use (None defers to the SFTTrainer default)
max_seq_length = None

# Pack multiple short examples in the same input sequence to increase efficiency
packing = False

# Load the entire model on GPU 0
device_map = {"": 0}

def cross_entropy_evaluation(model, tokenizer, test_list, bos="", eos=""):
    """Return the mean per-sequence next-token cross-entropy of ``model``.

    Every text is wrapped as ``bos + text + eos``, tokenized on its own (no
    padding, batch size 1) and scored with the usual causal-LM shift: the
    logits at position ``t`` are compared with the token at ``t + 1``. The
    token-averaged loss of each sequence is then averaged over sequences.

    Args:
        model: Causal language model. It is called as
            ``model(input_ids, attention_mask=...)`` and must return an
            object exposing ``logits``. The model is switched to eval mode
            and left in that mode.
        tokenizer: Callable mapping a list of strings to a mapping with
            ``"input_ids"`` and ``"attention_mask"`` entries, one list of
            ints per input string.
        test_list: Texts to score.
        bos: String prepended to each text before tokenization.
        eos: String appended to each text before tokenization.

    Returns:
        The average loss as a float, or ``nan`` when nothing could be scored
        (``test_list`` is empty or every sequence has fewer than two tokens).
    """
    model.eval()

    texts = [bos + txt + eos for txt in test_list]
    if not texts:
        # Nothing to average; avoid dividing by zero.
        return float("nan")

    # Tensors built from Python lists live on the CPU, so they are created
    # directly on the device that holds the model's weights.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    tokenized = tokenizer(texts)
    total_loss = 0.0
    scored = 0

    with torch.no_grad():
        for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"]):
            if len(ids) < 2:
                # A single token has no next-token target; scoring it would
                # produce NaN and poison the whole average.
                continue

            input_ids = torch.tensor(ids, device=device).unsqueeze(0)
            attention_mask = torch.tensor(mask, device=device).unsqueeze(0)

            # The loss is computed below, so the model's own ``labels`` path
            # (whose result was previously discarded) is not triggered.
            outputs = model(input_ids, attention_mask=attention_mask)

            # Shift: predictions for positions 0..T-2 vs. tokens 1..T-1.
            logits = outputs.logits[:, :-1, :]
            targets = input_ids[:, 1:].to(logits.device)

            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )
            total_loss += loss.item()
            scored += 1

    if not scored:
        return float("nan")
    return total_loss / scored

class LogToFileCallback(TrainerCallback):
    """Trainer callback that appends each logged record to a text file.

    The target file is ``<log_dir>/<log_file_name>``; ``log_dir`` is created
    on construction if it does not exist yet.
    """

    def __init__(self, log_dir, log_file_name):
        super().__init__()
        os.makedirs(log_dir, exist_ok=True)
        self.log_dir = log_dir
        self.log_file_name = log_file_name

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append ``str(logs)`` as one line; do nothing when ``logs`` is None."""
        if logs is None:
            return
        destination = os.path.join(self.log_dir, self.log_file_name)
        with open(destination, "a") as sink:
            sink.write(f"{logs!s}\n")


def _read_corpus(path):
    """Return the blank-line-separated samples stored in the text file ``path``."""
    with open(path, 'r') as handle:
        # The last chunk is discarded, as before. NOTE(review): this presumably
        # relies on files ending with a blank line; a file without one loses
        # its final sample — confirm against the data format.
        return handle.read().split('\n\n')[:-1]


def finetune(file_path, file_name, output_folder, model, tokenizer):
    """Fine-tune ``model`` with LoRA on one dataset and record its test loss.

    Steps:
      1. Read ``train.txt``, ``val.txt`` and ``test.txt`` from ``file_path``.
      2. Train with ``SFTTrainer`` using the module-level hyper-parameters,
         evaluating on the validation split and stopping early after one
         evaluation without improvement in ``eval_loss``.
      3. Pickle every parameter whose name contains ``"lora"`` to
         ``<output_folder>/<file_name>/<file_name>_llamadelta.pkl``.
      4. Append the test-set cross-entropy to
         ``<output_folder>/<file_name>/llama_log.txt``.

    Args:
        file_path: Directory containing the three split files.
        file_name: Dataset name; used for the output sub-folder and file names.
        output_folder: Parent directory under which the dataset's folder is
            created.
        model: Base causal LM to train. It is modified by training and left
            in eval mode.
        tokenizer: Tokenizer matching ``model``.
    """
    print("Now training: ", file_name)

    peft_config = LoraConfig(
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        r=lora_r,
        bias="none",
        task_type="CAUSAL_LM",
    )

    output_folder = os.path.join(output_folder, file_name)
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(output_folder, exist_ok=True)

    train_corpus = _read_corpus(os.path.join(file_path, 'train.txt'))
    val_corpus = _read_corpus(os.path.join(file_path, 'val.txt'))
    test_corpus = _read_corpus(os.path.join(file_path, 'test.txt'))

    train_dataset = Dataset.from_dict({"text": train_corpus})
    val_dataset = Dataset.from_dict({"text": val_corpus})

    model.train()
    # NOTE(review): the module-level `gradient_checkpointing` flag is not
    # forwarded here — confirm whether that is intentional.
    training_arguments = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        # Previously defined at module level but never passed, so evaluation
        # silently used the library default batch size.
        per_device_eval_batch_size=per_device_eval_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim=optim,
        save_steps=save_steps,
        logging_steps=logging_steps,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        fp16=fp16,
        bf16=bf16,
        max_grad_norm=max_grad_norm,
        max_steps=max_steps,
        warmup_ratio=warmup_ratio,
        group_by_length=group_by_length,
        lr_scheduler_type=lr_scheduler_type,
        evaluation_strategy="steps",  # Evaluate every `eval_steps`.
        # A value below 1 is read as a fraction of the total training steps,
        # so this evaluates every quarter of an epoch.
        eval_steps=0.25/num_train_epochs,
        load_best_model_at_end=True,  # Load the best model found during training at the end of training
        metric_for_best_model="eval_loss",  # Use evaluation loss for early stopping
        greater_is_better=False,  # Smaller evaluation loss is better
        logging_dir=output_folder
    )

    early_stopping = EarlyStoppingCallback(early_stopping_patience=1)

    # Set supervised fine-tuning parameters
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        args=training_arguments,
        callbacks=[early_stopping, LogToFileCallback(training_arguments.logging_dir, 'llama_log.txt')],  # Add early stopping callback
        packing=packing,
    )

    trainer.train()

    # Snapshot the LoRA weights as plain CPU tensors detached from autograd,
    # so the pickle can be loaded on machines without the training device.
    diff_list = {
        name: param.detach().cpu()
        for name, param in model.named_parameters()
        if 'lora' in name
    }

    # Save the LoRA tensors to disk
    with open(os.path.join(output_folder, file_name+'_llamadelta.pkl'), 'wb') as f:
        pickle.dump(diff_list, f)

    model.eval()

    with torch.no_grad():
        avg_test_loss = cross_entropy_evaluation(model, tokenizer, test_corpus)

    # Appended to the same file the training callback writes to.
    with open(os.path.join(output_folder, 'llama_log.txt'), 'a') as file:
        file.write("Test Loss: {0:.2f}".format(avg_test_loss) + '\n')

def main(data_folder, output_folder):
    """Load the 4-bit base model once and fine-tune it on every dataset.

    Each sub-directory of ``data_folder`` is treated as one dataset and
    passed to ``finetune``; results are written below ``output_folder``.

    Args:
        data_folder: Directory whose sub-directories each hold one dataset.
        output_folder: Directory that receives one result folder per dataset;
            created if missing.

    NOTE(review): the same model object is handed to every ``finetune`` call,
    and ``finetune`` attaches LoRA adapters through ``SFTTrainer``. Whether
    adapters from an earlier dataset are reset or carried over on the next
    call depends on PEFT/TRL behaviour — confirm the runs are independent.
    """
    # os.listdir returns entries in arbitrary order and may include stray
    # files; keep only directories and sort them so the run order is
    # reproducible.
    data_file_names = sorted(
        entry
        for entry in os.listdir(data_folder)
        if os.path.isdir(os.path.join(data_folder, entry))
    )

    # Load tokenizer and model with QLoRA configuration
    compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=use_4bit,
        bnb_4bit_quant_type=bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=use_nested_quant,
    )

    # Check GPU compatibility with bfloat16
    if compute_dtype == torch.float16 and use_4bit:
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            print("=" * 80)
            print("Your GPU supports bfloat16: accelerate training with bf16=True")
            print("=" * 80)

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map=device_map
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    # Load LLaMA tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(output_folder, exist_ok=True)

    for file_name in data_file_names:
        file_path = os.path.join(data_folder, file_name)
        finetune(file_path, file_name, output_folder, model, tokenizer)

if __name__ == "__main__":
    # Two positional arguments are required; exit with a usage line instead
    # of an IndexError when they are missing. Extra arguments are ignored.
    if len(sys.argv) < 3:
        sys.exit("Usage: python {} <data_folder> <output_folder>".format(sys.argv[0]))
    data_folder = sys.argv[1]
    output_folder = sys.argv[2]
    main(data_folder, output_folder)
